from ..trainer_videobase import VideoBaseTrainer
import torch.nn.functional as F
from typing import Optional
import os
import torch
from transformers.utils import WEIGHTS_NAME
import json

class VQVAETrainer(VideoBaseTrainer):
    """Trainer that optimizes a VQ-VAE on video batches.

    The training loss is a scaled reconstruction MSE plus the commitment loss
    reported by the model's codebook. The forward pass runs under bfloat16
    autocast.
    """

    # Divisor applied to ``inputs["video"]`` before encoding. The
    # reconstruction target is the scaled tensor as well.
    # NOTE(review): presumably maps pixels from [-1, 1] to [-0.5, 0.5];
    # confirm against the dataset's normalization.
    input_scale = 2
    # Divisor applied to the reconstruction MSE.
    # NOTE(review): looks like a data-variance normalizer; confirm.
    recon_loss_scale = 0.06

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the VQ-VAE loss for one batch.

        Args:
            model: The model being trained. It may be wrapped in an object
                exposing the real model as ``.module`` (e.g.
                DistributedDataParallel) or passed bare. The unwrapped model
                must provide ``encoder``, ``pre_vq_conv``, ``codebook``,
                ``post_vq_conv`` and ``decoder``. ``codebook`` must return a
                mapping with ``"embeddings"`` and ``"commitment_loss"``.
            inputs: Batch mapping; the video tensor is read from ``"video"``.
            return_outputs: If True, return ``(loss, outputs)`` instead of only
                ``loss``, following the ``transformers.Trainer.compute_loss``
                convention.
            **kwargs: Extra keyword arguments some trainer versions pass
                (e.g. ``num_items_in_batch``). They are accepted and ignored.

        Returns:
            The loss tensor. When ``return_outputs`` is True, returns
            ``(loss, outputs)``, where ``outputs`` is a dict holding the
            reconstruction and the individual loss terms.
        """
        # Unwrap DDP/DataParallel-style containers, but also accept a bare
        # model instead of failing with AttributeError.
        model = getattr(model, "module", model)
        x = inputs.get("video")

        # Autocast must target the device the batch actually lives on. The
        # previous "cuda if available else npu" guess broke on CPU-only hosts.
        device_type = x.device.type

        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            x = x / self.input_scale
            z = model.pre_vq_conv(model.encoder(x))
            vq_output = model.codebook(z)
            x_recon = model.decoder(model.post_vq_conv(vq_output["embeddings"]))
            recon_loss = F.mse_loss(x_recon, x) / self.recon_loss_scale
            commitment_loss = vq_output["commitment_loss"]
            loss = recon_loss + commitment_loss

        if return_outputs:
            # Callers requesting outputs unpack a (loss, outputs) pair.
            # Returning the bare loss here would fail that unpacking.
            outputs = {
                "reconstructions": x_recon,
                "recon_loss": recon_loss,
                "commitment_loss": commitment_loss,
            }
            return loss, outputs
        return loss

